{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydrake.all import BezierCurve, MultilayerPerceptron, RandomGenerator, PerceptronActivationType\n",
    "\n",
    "import matplotlib.pyplot as plt, mpld3\n",
    "import numpy as np\n",
    "\n",
    "from manipulation import running_as_notebook\n",
    "\n",
    "if running_as_notebook:\n",
    "    mpld3.enable_notebook()\n",
    "\n",
    "rng = np.random.default_rng()\n",
    "generator = RandomGenerator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Denoising diffusion as simple projection onto a manifold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: use domain [-1, 1]\n",
    "traj = BezierCurve(\n",
    "    0, 1, 2*np.mat(\"0.1, 0.2; 0.4, 0.2; 0.5, 0.9; 0.7, 0.5; 0.9, 0.9\").T -1\n",
    ")\n",
    "\n",
    "def plot_trajectory(traj):\n",
    "    plt.figure()\n",
    "    plt.axis(\"square\")\n",
    "    plt.xlim([-1, 1])\n",
    "    plt.ylim([-1, 1])\n",
    "    ts = np.linspace(0, 1, 100)\n",
    "    xs = traj.vector_values(ts)\n",
    "    plt.plot(xs[0, :], xs[1, :])\n",
    "\n",
    "\n",
    "plot_trajectory(traj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# inputs: x_0[k], x_1[k], sigma_k\n",
    "# outputs: x_0[k+1], x_1[k+1]\n",
    "denoiser = MultilayerPerceptron([4, 255, 255, 2]) #, [PerceptronActivationType.kTanh, PerceptronActivationType.kTanh, PerceptronActivationType.kIdentity])\n",
    "context = denoiser.CreateDefaultContext()\n",
    "denoiser.SetRandomContext(context, generator)\n",
    "params = denoiser.GetMutableParameters(context)\n",
    "dloss_dparams = 0 * params\n",
    "\n",
    "# discrete noise levels (following Permenter23, Section 3.3.2)\n",
    "K = 100\n",
    "beta = 0.05  # choose 0 <= beta < 1\n",
    "sigma = [1]\n",
    "for i in range(K - 1, 0, -1):\n",
    "    sigma.append((1 - beta) * sigma[-1])\n",
    "sigma.reverse()\n",
    "sigma = np.asarray(sigma)\n",
    "print(f\"sigma ∈ [{sigma.min()}, {sigma.max()}]\")\n",
    "\n",
    "# learning rate\n",
    "eta = 0.001\n",
    "\n",
    "err = []\n",
    "for iter in range(500):\n",
    "    N = 1000\n",
    "    # generate N random samples from the manifold\n",
    "    xs = traj.vector_values(rng.uniform(0, 1, N))\n",
    "    # generate N random noise values\n",
    "    eps = rng.normal(0, 1, (2, N))\n",
    "    # generate N random noise levels\n",
    "    ks = rng.integers(0, K, N)\n",
    "    sigmas = sigma[ks]\n",
    "\n",
    "    X = np.vstack((xs + sigmas * eps, np.sin(ks), np.cos(ks)))\n",
    "    # plt.plot(np.vstack((X[0,:200], xs[0,:200])),np.vstack((X[1,:200], xs[1,:200])),color=str(0.5))\n",
    "\n",
    "    # Take a few steps of gradient descent.\n",
    "    # TODO(russt): Use ADAM.\n",
    "    this_err = 0\n",
    "    for step in range(1):\n",
    "        this_err += denoiser.BackpropagationMeanSquaredError(\n",
    "            context, X=X, Y_desired=eps, dloss_dparams=dloss_dparams\n",
    "        )\n",
    "        params -= eta * dloss_dparams\n",
    "    err.append(this_err)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(err)\n",
    "plt.xlabel(\"iter\")\n",
    "plt.ylabel(\"loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_trajectory(traj)\n",
    "\n",
    "# Plot the denoiser \"projection\" back to the manifold for a particular noise level\n",
    "k_plot = 50\n",
    "assert k_plot >= 0 and k_plot < K\n",
    "x0s, x1s = np.meshgrid(\n",
    "    np.linspace(-1, 1, 11), np.linspace(-1, 1, 11), indexing=\"ij\"\n",
    ")\n",
    "X = np.vstack(\n",
    "    (x0s.flatten(), x1s.flatten(), 0 * x0s.flatten() + np.sin(k_plot), 0 * x0s.flatten() + np.cos(k_plot))\n",
    ")\n",
    "Y = np.asfortranarray(np.zeros((2, X.shape[1])))\n",
    "denoiser.BatchOutput(context, X, Y)\n",
    "plt.plot(\n",
    "    np.vstack((X[0, :], X[0, :] - sigma[k_plot] * Y[0, :])),\n",
    "    np.vstack((X[1, :], X[1, :] - sigma[k_plot] * Y[1, :])),\n",
    "    color=str(0.5),\n",
    ");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_trajectory(traj)\n",
    "\n",
    "for i in range(20):\n",
    "    # Denoise a random initial sample\n",
    "    xk = rng.normal(0, 1, [2, 1])\n",
    "    plt.plot(xk[0], xk[1], \"x\", color=str(0.5))\n",
    "    for k in range(K - 1, 0, -1):\n",
    "        for j in range(2):\n",
    "            denoiser.get_input_port().FixValue(context, np.vstack((xk, np.sin(k), np.cos(k))))\n",
    "            eps = denoiser.get_output_port().Eval(context).reshape([2, 1])\n",
    "            xk = xk + (sigma[k - 1] - sigma[k]) * eps\n",
    "        plt.plot(xk[0], xk[1], \"x\", color=str(0.5 * k / K))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
